#!/usr/bin/env python3
"""Send and receive test packets"""
#
# Copyright (C) Hamish Coleman
# SPDX-License-Identifier: GPL-2.0-only

import argparse
import pprint
import socket
import struct


def hexdump(display_addr, data):
    """Print data to stdout as a hex dump, sixteen bytes per row.

    Each row shows the address (starting at display_addr and advancing by
    16 per row), the byte values in hex with an extra gap after the eighth
    column, and the printable ASCII characters between ``|`` markers.
    Bytes outside 0x20..0x7e are shown as a space.  A short final row has
    its hex columns padded, but its ASCII column is left short.
    """
    for start in range(0, len(data), 16):
        chunk = data[start:start + 16]

        # One cell per column; missing bytes on the last row become blanks
        cells = [f"{byte:02x}" for byte in chunk]
        cells += ["  "] * (16 - len(chunk))

        # Every cell is followed by a space, the eighth one by two
        hex_part = "".join(
            cell + ("  " if col == 7 else " ")
            for col, cell in enumerate(cells)
        )
        text_part = "".join(
            chr(byte) if 0x20 <= byte <= 0x7e else " "
            for byte in chunk
        )

        print(f"{display_addr + start:03x}: {hex_part}  |{text_part}|")


class PacketBase():
    """Shared field storage and wire helpers for all packet types.

    Field values are kept in the ``data`` dict and the wire representation
    in ``buffer`` (``None`` until the packet has been encoded or decoded).
    Subclasses implement ``encode()`` and ``decode()`` on top of the
    helpers provided here.
    """

    # Field name -> fixed value applied by stable_values()
    unstable_fields = {}

    # Packet type names, indexed by the numeric packet type
    names = [
        None,
        "REGISTER",             # 1
        "DEREGISTER",           # this packet never used in original code
        "PACKET",
        "REGISTER_ACK",
        "REGISTER_SUPER",       # 5
        "UNREGISTER_SUPER",
        "REGISTER_SUPER_ACK",
        "REGISTER_SUPER_NAK",
        "FEDERATION",
        "PEER_INFO",            # 10
        "QUERY_PEER",
        "RE_REGISTER_SUPER",
    ]

    @classmethod
    def id2type(cls, id):
        """Return the name for a numeric packet type.

        Raises IndexError if the id is beyond the end of the names table.
        """
        return cls.names[id]

    @classmethod
    def type2id(cls, name):
        """Return the numeric packet type for a name.

        Raises ValueError if the name is not in the names table.
        """
        return cls.names.index(name)

    # Initial field values copied into every new packet object
    defaults = {
        "proto_version": 3,
        "ttl": 2,
        "flags": 0,
        "community": b"test",
        "cookie": 0x1000,
        "srcMac": b"\x00\x00\x00\x00\x00\x01",
        "dstMac": b"\x00\x00\x00\x00\x00\x02",
        "edgeMac": b"\x00\x00\x00\x00\x00\x03",
        "targetMac": b"\x00\x00\x00\x00\x00\x04",
        "ipv4": b"\x00\x00\x00\x00",
        "masklen": 0,
        "desc": b"packetgenerator",
        "auth_scheme": 0,   # n2n_auth_none
        "auth_token_size": 0,
        "auth_token": b"",
        "key_time": 600,
        "aflags": 0,
        "compression": 1,   # COMPRESSION_ID_NONE
        "transform": 1,     # TRANSFORM_ID_NULL
        "packet": b"\x02\x00\x00\x00\x00\x11"
                  b"\x02\x00\x00\x00\x00\x22"
                  b"\x0f\x0f"
                  b"\x55\xaa",
        "sock_family": socket.AF_INET,
        "sock_port": 8000,
        "sock_addr": b"\x01\x02\x03\x04",
    }

    @classmethod
    def set_default(cls, key, val):
        """Change a default field value for packets created afterwards.

        The defaults dict is shared, so this affects every packet class.
        """
        cls.defaults[key] = val

    @classmethod
    def get_default(cls, key):
        """Return the current default value of a field."""
        return cls.defaults[key]

    def __init__(self):
        # Set the default values for all the field data
        self.data = self.defaults.copy()

        # No wire representation exists until encode() or decode() runs.
        # Defining it here lets stable_values() test it safely.
        self.buffer = None

    def encode_common(self, pkt_type=None):
        """Initialize the buffer with just the common header fields

        If pkt_type is given it is stored as the "type" field first.
        Raises ValueError if no packet type is known.
        """
        self.buffer = b""

        if pkt_type is not None:
            self.data["type"] = pkt_type

        if "type" not in self.data:
            raise ValueError("Unknown packet type")

        # special case for this one field: the low five bits carry the
        # packet type and the remaining bits carry the flags
        type_flags = self.data["type"] & 0x1f
        type_flags |= self.data["flags"] & 0xffe0

        self.buffer += struct.pack(
            "!BBH20s",
            self.data["proto_version"],
            self.data["ttl"],
            type_flags,
            self.data["community"]
        )

    def decode_common(self, buffer):
        """Extract just the common fields from the buffer

        Replaces any existing field data.  Returns the number of bytes
        consumed from the start of the buffer.
        """
        self.data = {}

        # Set the values in the expected sort order
        self.data["_name"] = "TBD"
        self.data["proto_version"] = "TBD"
        self.data["ttl"] = "TBD"
        self.data["type"] = "TBD"
        self.data["flags"] = "TBD"

        offset = self.unpack_fields(
                buffer,
                "!BBH20s",
                "proto_version",
                "ttl",
                "type",
                "community",
        )

        # Split the combined 16 bit value back into flags and type
        combined = self.data["type"]
        self.data["flags"] = combined & 0xffe0
        pkt_type = combined & 0x1f
        self.data["type"] = pkt_type

        try:
            self.data["_name"] = self.id2type(pkt_type)
        except IndexError:
            # Unknown type numbers keep the placeholder name
            pass
        return offset

    def encode_sock(self, prefix="sock"):
        """Append a socket address to the buffer.

        Reads the ``{prefix}_family``, ``{prefix}_port`` and
        ``{prefix}_addr`` fields.  If a ``{prefix}_type`` field is present
        and is SOCK_STREAM, the 0x4000 flag is set, matching what
        decode_sock() reads.  Raises ValueError for an unknown family.
        """
        family = self.data[f"{prefix}_family"]

        if family == socket.AF_INET:
            fmt = "!HH4s"
            sock_flag = 0
        elif family == socket.AF_INET6:
            fmt = "!HH16s"
            sock_flag = 0x8000
        else:
            raise ValueError("Unknown sock_family")

        # The type field is optional (it is not part of the defaults), so
        # packets that never set it are encoded exactly as before
        if self.data.get(f"{prefix}_type") == socket.SOCK_STREAM:
            sock_flag |= 0x4000

        self.buffer += struct.pack(
            fmt,
            sock_flag,
            self.data[f"{prefix}_port"],
            self.data[f"{prefix}_addr"],
        )

    def decode_sock(self, buffer, prefix="sock"):
        """Extract a socket address from the start of the buffer.

        Stores the ``{prefix}_family``, ``{prefix}_type``,
        ``{prefix}_port`` and ``{prefix}_addr`` fields and returns the
        number of bytes consumed.
        """
        sock_flag = struct.unpack("!H", buffer[:2])[0]

        # Bit 0x8000 selects the address family (and thus the size)
        if sock_flag & 0x8000:
            self.data[f"{prefix}_family"] = socket.AF_INET6
            fmt = "!HH16s"
        else:
            self.data[f"{prefix}_family"] = socket.AF_INET
            fmt = "!HH4s"

        # Bit 0x4000 selects the socket type
        if sock_flag & 0x4000:
            self.data[f"{prefix}_type"] = socket.SOCK_STREAM
        else:
            self.data[f"{prefix}_type"] = socket.SOCK_DGRAM

        calcsize = struct.calcsize(fmt)
        _, port, addr = struct.unpack(fmt, buffer[:calcsize])
        self.data[f"{prefix}_port"] = port
        self.data[f"{prefix}_addr"] = addr
        return calcsize

    def encode(self):
        """Build and return the wire form; provided by subclasses."""
        raise NotImplementedError

    def decode(self, buffer):
        """Parse a wire buffer into fields; provided by subclasses."""
        raise NotImplementedError

    def stable_values(self):
        """Remove fields that will constantly change - allowing testing

        Every field listed in unstable_fields is overwritten with its
        fixed value.  If anything was overwritten and a buffer exists, the
        buffer is rebuilt with encode() so it matches the field data.
        """
        changed = bool(self.unstable_fields)
        self.data.update(self.unstable_fields)

        if changed and self.buffer is not None:
            self.encode()

    def flags_socket(self):
        """Return the 0x0040 flag bit (non-zero when it is set)."""
        return self.data["flags"] & 0x0040

    def pack_fields(self, format, *fields):
        """Append the named fields to the buffer using a struct format."""
        fielddata = [self.data[f] for f in fields]
        self.buffer += struct.pack(format, *fielddata)

    def unpack_fields(self, buffer, format, *fields):
        """Unpack the start of the buffer into the named fields.

        Returns the number of bytes the format consumed.  Raises
        IndexError if the format yields more values than names given.
        """
        calcsize = struct.calcsize(format)
        values = struct.unpack(format, buffer[:calcsize])

        for i, val in enumerate(values):
            self.data[fields[i]] = val

        return calcsize

    def is_header_encrypted(self):
        """Report whether the decoded header looks encrypted.

        This is a heuristic: it returns True when the common fields are
        not plausible for a cleartext header.
        """
        if self.data["community"][19] != 0:
            # the community string is not null terminated
            return True
        if self.data["type"] > 12:
            # no packet types this large exist
            return True
        if self.data["flags"] >= 0x0080:
            # no flags exist beyond this
            return True
        return False


class PacketREGISTER(PacketBase):
    """REGISTER packet (type 1)."""

    format1 = "!I6s6s"
    format2 = "!4sB16s"

    def encode(self):
        """Build the wire buffer from the field data and return it."""
        leading = ("cookie", "srcMac", "edgeMac")
        trailing = ("ipv4", "masklen", "desc")

        self.encode_common(1)
        self.pack_fields(self.format1, *leading)

        # The socket is only on the wire when the flag bit says so
        if self.flags_socket():
            self.encode_sock()

        self.pack_fields(self.format2, *trailing)
        return self.buffer

    def decode(self, buffer):
        """Parse buffer into the field data and return the field dict."""
        pos = self.decode_common(buffer)

        consumed = self.unpack_fields(
            buffer[pos:], self.format1, "cookie", "srcMac", "edgeMac"
        )
        pos += consumed

        # The socket is only on the wire when the flag bit says so
        if self.flags_socket():
            pos += self.decode_sock(buffer[pos:])

        self.unpack_fields(
            buffer[pos:], self.format2, "ipv4", "masklen", "desc"
        )

        self.buffer = buffer
        return self.data


class PacketPACKET(PacketBase):
    """PACKET packet (type 3), carrying a payload in the "packet" field."""

    format1 = "!6s6s"
    format2 = "!BB"

    def encode(self):
        """Build the wire buffer from the field data and return it.

        The payload is packed as a fixed 64 byte string, so shorter
        payloads are null padded and longer ones are cut at 64 bytes.
        """
        self.encode_common(3)
        self.pack_fields(self.format1, "srcMac", "edgeMac")

        # The socket is only on the wire when the flag bit says so
        if self.flags_socket():
            self.encode_sock()

        self.pack_fields(self.format2, "compression", "transform")

        payload = struct.pack("64s", self.data["packet"])
        self.buffer += payload
        return self.buffer

    def decode(self, buffer):
        """Parse buffer into the field data and return the field dict.

        Everything after the fixed fields is stored as the payload.
        """
        pos = self.decode_common(buffer)
        pos += self.unpack_fields(
            buffer[pos:], self.format1, "srcMac", "edgeMac"
        )

        # The socket is only on the wire when the flag bit says so
        if self.flags_socket():
            pos += self.decode_sock(buffer[pos:])

        pos += self.unpack_fields(
            buffer[pos:], self.format2, "compression", "transform"
        )

        self.data["packet"] = buffer[pos:]
        self.buffer = buffer
        return self.data


class PacketREGISTER_SUPER(PacketBase):
    """REGISTER_SUPER packet (type 5)."""

    format1 = "!I6s"
    format2 = "!4sB16sHH"
    format3 = "!L"

    def encode(self):
        """Build the wire buffer from the field data and return it."""
        self.encode_common(5)
        self.pack_fields(self.format1, "cookie", "edgeMac")

        # The socket is only on the wire when the flag bit says so
        if self.flags_socket():
            self.encode_sock()

        middle = ("ipv4", "masklen", "desc", "auth_scheme", "auth_token_size")
        self.pack_fields(self.format2, *middle)

        # The token is variable length and omitted entirely when empty
        token_len = self.data["auth_token_size"]
        if token_len:
            self.pack_fields(f"{token_len}s", "auth_token")

        self.pack_fields(self.format3, "key_time")
        return self.buffer

    def decode(self, buffer):
        """Parse buffer into the field data and return the field dict."""
        pos = self.decode_common(buffer)
        pos += self.unpack_fields(
            buffer[pos:], self.format1, "cookie", "edgeMac"
        )

        # The socket is only on the wire when the flag bit says so
        if self.flags_socket():
            pos += self.decode_sock(buffer[pos:])

        pos += self.unpack_fields(
            buffer[pos:],
            self.format2,
            "ipv4", "masklen", "desc", "auth_scheme", "auth_token_size",
        )

        # The token is variable length and omitted entirely when empty
        token_len = self.data["auth_token_size"]
        if token_len:
            token_end = pos + token_len
            self.data["auth_token"] = buffer[pos:token_end]
            pos = token_end

        self.unpack_fields(buffer[pos:], self.format3, "key_time")

        self.buffer = buffer
        return self.data


class PacketUNREGISTER_SUPER(PacketBase):
    """UNREGISTER_SUPER packet (type 6)."""

    format1 = "!HH"
    format2 = "!6s"

    def encode(self):
        """Build the wire buffer from the field data and return it."""
        self.encode_common(6)
        self.pack_fields(self.format1, "auth_scheme", "auth_token_size")

        # The token is variable length and omitted entirely when empty
        token_len = self.data["auth_token_size"]
        if token_len:
            self.pack_fields(f"{token_len}s", "auth_token")

        self.pack_fields(self.format2, "edgeMac")
        return self.buffer

    def decode(self, buffer):
        """Parse buffer into the field data and return the field dict."""
        pos = self.decode_common(buffer)
        pos += self.unpack_fields(
            buffer[pos:], self.format1, "auth_scheme", "auth_token_size"
        )

        # The token is variable length and omitted entirely when empty
        token_len = self.data["auth_token_size"]
        if token_len:
            token_end = pos + token_len
            self.data["auth_token"] = buffer[pos:token_end]
            pos = token_end

        self.unpack_fields(buffer[pos:], self.format2, "edgeMac")

        self.buffer = buffer
        return self.data


class PacketREGISTER_SUPER_ACK(PacketBase):
    """REGISTER_SUPER_ACK packet (type 7).

    Packets with a non-zero "num_sn" count are not supported and raise
    NotImplementedError on both encode and decode.
    """

    format1 = "!I6s4sBH"
    format2 = "!HH"
    format3 = "!B"
    format4 = "!L"

    def encode(self):
        """Build the wire buffer from the field data and return it."""
        leading = ("cookie", "srcMac", "ipv4", "masklen", "lifetime")

        self.encode_common(7)
        self.pack_fields(self.format1, *leading)

        # This packet type always carries a socket
        self.encode_sock()
        self.pack_fields(self.format2, "auth_scheme", "auth_token_size")

        # The token is variable length and omitted entirely when empty
        token_len = self.data["auth_token_size"]
        if token_len:
            self.pack_fields(f"{token_len}s", "auth_token")

        self.pack_fields(self.format3, "num_sn")
        if self.data["num_sn"]:
            raise NotImplementedError

        self.pack_fields(self.format4, "key_time")
        return self.buffer

    def decode(self, buffer):
        """Parse buffer into the field data and return the field dict."""
        pos = self.decode_common(buffer)
        pos += self.unpack_fields(
            buffer[pos:],
            self.format1,
            "cookie", "srcMac", "ipv4", "masklen", "lifetime",
        )

        # This packet type always carries a socket
        pos += self.decode_sock(buffer[pos:])
        pos += self.unpack_fields(
            buffer[pos:], self.format2, "auth_scheme", "auth_token_size"
        )

        # The token is variable length and omitted entirely when empty
        token_len = self.data["auth_token_size"]
        if token_len:
            token_end = pos + token_len
            self.data["auth_token"] = buffer[pos:token_end]
            pos = token_end

        pos += self.unpack_fields(buffer[pos:], self.format3, "num_sn")
        if self.data["num_sn"]:
            raise NotImplementedError

        self.unpack_fields(buffer[pos:], self.format4, "key_time")

        self.buffer = buffer
        return self.data


class PacketPEER_INFO(PacketBase):
    """PEER_INFO packet (type 10).

    The load, uptime and version fields are replaced with fixed values
    by stable_values().
    """

    format1 = "!H6s6s"
    format2 = "!LL20s"
    unstable_fields = {
        "load": 0xffffffff,
        "uptime": 0xffffffff,
        "version": b"_stabilised_",
    }

    def encode(self):
        """Build the wire buffer from the field data and return it."""
        self.encode_common(10)
        self.pack_fields(self.format1, "aflags", "srcMac", "mac")

        # The first socket is always present, the second one is flagged
        self.encode_sock()
        if self.flags_socket():
            self.encode_sock(prefix="sock2")

        trailing = ("load", "uptime", "version")
        self.pack_fields(self.format2, *trailing)
        return self.buffer

    def decode(self, buffer):
        """Parse buffer into the field data and return the field dict."""
        pos = self.decode_common(buffer)
        pos += self.unpack_fields(
            buffer[pos:], self.format1, "aflags", "srcMac", "mac"
        )

        # The first socket is always present, the second one is flagged
        pos += self.decode_sock(buffer[pos:])
        if self.flags_socket():
            pos += self.decode_sock(buffer[pos:], prefix="sock2")

        self.unpack_fields(
            buffer[pos:], self.format2, "load", "uptime", "version"
        )

        self.buffer = buffer
        return self.data


class PacketQUERY_PEER(PacketBase):
    """QUERY_PEER packet (type 11)."""

    format1 = "!6s6sH"

    def encode(self):
        """Build the wire buffer from the field data and return it."""
        body = ("srcMac", "edgeMac", "aflags")

        self.encode_common(11)
        self.pack_fields(self.format1, *body)
        return self.buffer

    def decode(self, buffer):
        """Parse buffer into the field data and return the field dict."""
        pos = self.decode_common(buffer)
        body = buffer[pos:]
        self.unpack_fields(body, self.format1, "srcMac", "edgeMac", "aflags")

        self.buffer = buffer
        return self.data


class PacketQUERY_PEER_ping(PacketQUERY_PEER):
    """QUERY_PEER variant whose edgeMac starts out as all zeros."""

    def __init__(self):
        super().__init__()
        # Replace the default edgeMac with the all-zero address
        self.data.update(edgeMac=b"\x00\x00\x00\x00\x00\x00")


class PacketGeneric(PacketBase):
    """Factory for packet objects, from wire data or a scenario name."""

    # Numeric packet type -> class that can decode it
    specific = {
        1: PacketREGISTER,
        # 2: DEREGISTER, but never used
        3: PacketPACKET,
        # 4: PacketREGISTER_ACK,
        5: PacketREGISTER_SUPER,
        6: PacketUNREGISTER_SUPER,
        7: PacketREGISTER_SUPER_ACK,
        # 8: PacketREGISTER_SUPER_NAK,
        # 9: PacketFEDERATION,
        10: PacketPEER_INFO,
        11: PacketQUERY_PEER,
        # 12: PacketRE_REGISTER_SUPER,
    }

    # Scenario name -> class used to build the packet to send
    scenarios = {
        "test_REGISTER": PacketREGISTER,
        "test_PACKET": PacketPACKET,
        "test_REGISTER_SUPER": PacketREGISTER_SUPER,
        "test_UNREGISTER_SUPER": PacketUNREGISTER_SUPER,
        "test_QUERY_PEER": PacketQUERY_PEER,
        "test_QUERY_PEER_ping": PacketQUERY_PEER_ping,
    }

    @classmethod
    def from_buffer(cls, buffer):
        """Decode wire data into an object of the matching packet class.

        Raises NotImplementedError if the header looks encrypted, or if
        no class is registered for the packet type (in which case the
        base class decode() is what raises).
        """
        probe = PacketGeneric()
        probe.decode_common(buffer)
        if probe.is_header_encrypted():
            raise NotImplementedError

        # Fall back to the generic object when the type has no decoder
        handler = cls.specific.get(probe.data["type"])
        obj = probe if handler is None else handler()

        obj.decode(buffer)
        return obj

    @classmethod
    def from_scenario(cls, name):
        """Create the packet object registered for a scenario name.

        Raises ValueError if the scenario name is not known.
        """
        try:
            factory = cls.scenarios[name]
        except KeyError:
            raise ValueError("Unknown scenario") from None
        return factory()


# to supernode:
# - PACKET -> forwarded to another node
# - REGISTER -> forwarded
# - REGISTER_SUPER -> REGISTER_SUPER_NAK or REGISTER_SUPER_ACK and/or forwarded
# - UNREGISTER_SUPER -> deletes peer, no reply
# - QUERY_PEER -> PEER_INFO
# - PEER_INFO -> forwarded

# to edge:
# - PACKET -> processed to tun
# - REGISTER -> if !from_supernode, REGISTER_ACK
# - PEER_INFO -> REGISTER
# - RE_REGISTER_SUPER -> REGISTER_SUPER , multicast REGISTER


def _recv_exact(sock, size):
    """Read exactly size bytes from a stream socket.

    Raises ConnectionError if the peer closes the connection before all
    of the bytes have arrived.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise ConnectionError(
                f"connection closed with {remaining} of {size} bytes unread"
            )
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def main():
    """Send one scenario packet and print the first reply.

    The packet to send is chosen by the scenario argument; the special
    scenario "listen" sends nothing and only waits for a packet.  With a
    timeout of zero the reply is not waited for at all.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument(
        "-s",
        "--server",
        help="host and port of server to contact",
        default="localhost:1968"
    )
    ap.add_argument(
        "--bind",
        help="bind local socket to port",
        type=int,
        default=False
    )
    ap.add_argument(
        "-c",
        "--community",
        help="The n3n community name",
        default="test"
    )
    ap.add_argument(
        "-t",
        "--timeout",
        type=int,
        help="how long to wait for replies",
        default=10
    )
    ap.add_argument(
        "--raw",
        action="store_true",
        help="Dont stablise the results - show the real data",
        default=False
    )
    ap.add_argument(
        "--tcp",
        action="store_true",
        help="Switch to TCP connection",
        default=False,
    )
    ap.add_argument("scenario", help="What packet to send")

    args = ap.parse_args()

    host, _, port = args.server.partition(":")
    try:
        port = int(port)
    except ValueError:
        # ap.error() prints the usage message and exits
        ap.error(f"server must be given as host:port, got {args.server!r}")

    community = args.community.encode("utf8")
    PacketBase.set_default("community", community)

    sock_type = socket.SOCK_STREAM if args.tcp else socket.SOCK_DGRAM

    # The context manager closes the socket on every exit path
    with socket.socket(socket.AF_INET, sock_type) as sock:
        if args.tcp:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        # A timeout of zero means "do not wait for replies".  Passing it
        # to settimeout() would instead make the socket non-blocking and
        # break the TCP connect below, so it is left in blocking mode.
        if args.timeout:
            sock.settimeout(args.timeout)

        if args.bind:
            sock.bind(('', args.bind))

        if args.tcp:
            sock.connect((host, port))

        if args.scenario != "listen":
            send_pkt = PacketGeneric.from_scenario(args.scenario)
            buf = send_pkt.encode()
            print("test:")
            hexdump(0, buf)

            if args.tcp:
                # add framing: a 16 bit big endian length before the data.
                # sendall() is used as the socket is already connected
                # and it retries until everything has been written.
                header = struct.pack('>H', len(buf))
                sock.sendall(header + buf)
            else:
                sock.sendto(buf, (host, port))

        if args.timeout == 0:
            return

        if args.tcp:
            # fetch framing, then exactly the announced number of bytes
            header = _recv_exact(sock, 2)
            size = struct.unpack('>H', header)[0]
            data = _recv_exact(sock, size)
        else:
            data = sock.recv(1600)

    # from_buffer() already fully decodes the packet
    recv_pkt = PacketGeneric.from_buffer(data)

    if not args.raw:
        recv_pkt.stable_values()

    print("recv:")
    hexdump(0, recv_pkt.buffer)
    pprint.pp(recv_pkt.data)


if __name__ == '__main__':
    main()
